import sympy
import re
from sympy.codegen.ast import Expr
# Matches a C-style ``pow(base, exp)`` call. ``base`` is everything up to the
# first comma and ``exp`` everything up to the first closing parenthesis, so
# arguments that themselves contain a comma or a parenthesised sub-expression
# are not captured correctly.
POW_RE = re.compile(r'pow\((?P<base>[^,]*)\s*,\s*(?P<exp>[^\)]*)\)')
def expandpow(m):
  """Rewrite one matched ``pow(base, exp)`` call as repeated multiplication.

  Args:
    m: a regex match object exposing the named groups ``base`` and ``exp``.

  Returns:
    ``(base)*(base)*...`` with ``exp`` factors when ``exp`` is a positive
    integer literal. For any other exponent (zero, negative, or not an
    integer literal) the matched text is returned unchanged, so the output
    keeps the original ``pow(...)`` call instead of becoming an empty string
    or raising ``ValueError``.
  """
  try:
    exp = int(m['exp'])
  except ValueError:
    # Non-integer exponent (e.g. "1.0/2.0"): leave the pow() call alone.
    return m[0]
  if exp < 1:
    # An empty product would silently delete the whole term.
    return m[0]
  return '*'.join(['(' + m['base'] + ')'] * exp)
def expand(s):
  """Return ``s`` with every POW_RE match replaced by ``expandpow``'s result."""
  expanded = re.sub(POW_RE, expandpow, s)
  return expanded
def rdot(i0, i1, i2, i3):
  """Return the sympy symbol ``r<a><b><c><d>`` for two index pairs.

  The pairs ``(i0, i1)`` and ``(i2, i3)`` are put in lexicographic order
  before the name is built, so swapping the two pairs yields the same symbol.
  """
  lo, hi = (i0, i1), (i2, i3)
  if not lo < hi:
    lo, hi = hi, lo
  return sympy.symbols('r%d%d%d%d' % (lo + hi))
def _mass_factor(k1, k2, kk1, kk2):
  """Return the C text of the signed ``rmass`` sum coupling two constraints.

  Constraint ``(k1, k2)`` picks up ``+rmass<k1>`` / ``-rmass<k1>`` when ``k1``
  equals ``kk1`` / ``kk2``, and ``+rmass<k2>`` / ``-rmass<k2>`` when ``k2``
  equals ``kk2`` / ``kk1``. When the two constraints share no atom the factor
  is ``0`` so the caller never emits an empty ``()``.
  """
  terms = []
  d1 = int(k1 == kk1) - int(k1 == kk2)
  d2 = int(k2 == kk2) - int(k2 == kk1)
  if d1 == -1:
    terms.append('-rmass%d' % k1)
  elif d1 == 1:
    terms.append('+rmass%d' % k1)
  if d2 == -1:
    terms.append('-rmass%d' % k2)
  elif d2 == 1:
    terms.append('+rmass%d' % k2)
  return ''.join(terms) or '0'


def generate_shake_series(natoms, constraints, name='rattle'):
  """Print the C source of a constraint-solver function to stdout.

  The emitted function ``void <name>(celldata_t *cell, int i)``:
    * resolves the atom indices ``ig0..ig<natoms-1>`` (``ig0`` is ``i``, the
      others come from ``cell->bonded_id``) and loads their ``rmass`` values;
    * forms, per constraint, the difference vectors ``r<k1><k2>`` (from
      ``cell->x``) and ``vp<k1><k2>`` (from ``cell->shake_vp``);
    * builds the symmetric matrix ``a<i><j>``, and its inverse as the
      sympy-computed ``inverse * determinant`` scaled by ``invmatdet``;
    * solves for the multipliers ``l<k>`` from ``c<k> = -vecdot(vp, r)``;
    * applies them to ``cell->v`` with ``vecscaleaddv``.
  NOTE(review): going by the default name this is presumably a RATTLE-style
  velocity-constraint step; this generator only emits text.

  Args:
    natoms: number of atoms in the constrained group.
    constraints: sequence of ``[k1, k2]`` pairs of local atom indices
      (``0 <= k < natoms``), one per constraint.
    name: name of the generated C function.

  Identifiers are built by concatenating indices (``r%d%d``, ``a%d%d``), so
  indices of 10 or more can produce colliding names.
  """
  ncons = len(constraints)

  def symmetric_symbols(fmt):
    # Symbol matrix whose (i, j) and (j, i) entries share one name.
    return sympy.Matrix([[sympy.symbols(fmt % (min(i, j), max(i, j)))
                          for j in range(ncons)] for i in range(ncons)])

  # Function header and per-atom setup.
  args = ['celldata_t *cell', 'int i']
  print('void %s(%s){' % (name, (', ').join(args)))
  print('  int ig0 = i;')
  # In the emitted text, the `i` inside first_bonded[i] / shake[i] is the C
  # parameter, not the Python loop variable.
  for i in range(1, natoms):
    print('  int ig%d = cell->bonded_id[cell->first_bonded[i] + cell->shake[i].idx[%d]];' % (i, i - 1))
  for i in range(natoms):
    print('  real rmass%d = cell->rmass[ig%d];' % (i, i))

  # Difference vectors for every constraint.
  decls = ['r%d%d, vp%d%d' % (k1, k2, k1, k2) for k1, k2 in constraints]
  print('  vec<real> ' + ', '.join(decls) + ';')
  for k1, k2 in constraints:
    print('  vecsubv(r%d%d, cell->x[ig%d], cell->x[ig%d]);' % (k1, k2, k2, k1))
  for k1, k2 in constraints:
    print('  vecsubv(vp%d%d, cell->shake_vp[ig%d], cell->shake_vp[ig%d]);' % (k1, k2, k2, k1))

  # Matrix elements: the upper triangle is computed, the lower triangle is
  # copied from the already-declared transposed element.
  for ik, (k1, k2) in enumerate(constraints):
    for ikk, (kk1, kk2) in enumerate(constraints):
      if ikk < ik:
        print('  real a%d%d = a%d%d;' % (ik, ikk, ikk, ik))
        continue
      print('  real a%d%d = vecdot(r%d%d, r%d%d)*(%s);'
            % (ik, ikk, k1, k2, kk1, kk2, _mass_factor(k1, k2, kk1, kk2)))

  # Symbolic determinant and inverse*determinant, printed as C expressions.
  mat = symmetric_symbols('a%d%d')
  matdet = mat.det()
  # Kept from the original. NOTE(review): probably not needed for ccode()
  # output; confirm before removing.
  sympy.init_printing()
  print('  real invmatdet = 1/(' + expand(sympy.printing.ccode(matdet)) + ');')
  matinvdet = sympy.simplify(mat.inv() * matdet)
  for i in range(mat.shape[0]):
    for j in range(i, mat.shape[1]):
      print('  real a%d%dinv = (' % (i, j) + expand(sympy.printing.ccode(matinvdet[i, j])) + ') * invmatdet;')

  # Right-hand side and multipliers l = A^-1 * c.
  matinv = symmetric_symbols('a%d%dinv')
  for ik, (k1, k2) in enumerate(constraints):
    print('  real c%d = -vecdot(vp%d%d, r%d%d);' % (ik, k1, k2, k1, k2))
  c = sympy.Matrix([sympy.symbols('c%d' % i) for i in range(ncons)])
  lvec = matinv * c
  for ik in range(ncons):
    print('  real l%d = %s;' % (ik, expand(sympy.printing.ccode(lvec[ik]))))

  # Velocity update: '-' for the first atom of a constraint, '+' for the second.
  for i in range(natoms):
    for ik, (k1, k2) in enumerate(constraints):
      if i == k1:
        sign = '-'
      elif i == k2:
        sign = '+'
      else:
        continue
      print('  vecscaleaddv(cell->v[ig%d], cell->v[ig%d], r%d%d, 1, %sl%d * rmass%d);' % (i, i, k1, k2, sign, ik, i))
  print('}')
# Emit one C function per supported constraint topology, in this order.
for _natoms, _constraints, _name in (
  (2, [[0, 1]], 'rattle2'),
  (3, [[0, 1], [0, 2]], 'rattle3'),
  (3, [[0, 1], [0, 2], [1, 2]], 'rattle3angle'),
  (4, [[0, 1], [0, 2], [0, 3]], 'rattle4'),
):
  generate_shake_series(_natoms, _constraints, _name)